from torchvision import datasets, transforms
import numpy as np
import torch


def _make_transform():
    """Build the ToTensor + Normalize pipeline attached to every dataset below.

    The same constants (0.1307, 0.3081) are used for MNIST and EMNIST alike.
    A fresh ``Compose`` is returned on each call so no dataset shares a
    transform object with another.
    """
    return transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])


# MNIST digits. ``download=True`` is passed for the training split only.
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=_make_transform(),
    download=True,
)
test_data = datasets.MNIST(
    root='data',
    train=False,
    transform=_make_transform(),
)

# EMNIST, 'letters' split. ``download=True`` is passed for the training split only.
train_data_emnist = datasets.EMNIST(
    root='data',
    split='letters',
    train=True,
    transform=_make_transform(),
    download=True,
)
test_data_emnist = datasets.EMNIST(
    root='data',
    split='letters',
    train=False,
    transform=_make_transform(),
)


# Collect a dataset using random policy
def data_collection(noise_type, noise, num_data, device):

    D = []

    for i in range(num_data):
        # Sample a context
        context_num = np.random.randint(10)
        idx_context = train_data.targets == context_num
        context_set = train_data.data[idx_context]
        context = context_set[np.random.randint(len(context_set))].float().to(device)

        # Sample an action
        rand_action = np.random.randint(10)
        one_hot_action = torch.zeros((1, 10)).to(device)
        one_hot_action[0][rand_action] = 1

        # Compute the latent reward
        if context_num == rand_action:
            latent_reward = 1
        else:
            latent_reward = 0

        # Sample a feedback
        if np.random.random() >= noise:
            idx_feedback = train_data.targets == latent_reward
            feedback_set = train_data.data[idx_feedback]
        else:
            if noise_type == 'Action_Inclusive':
                idx_feedback = train_data.targets == (rand_action + 6 * latent_reward - 3) % 10
                feedback_set = train_data.data[idx_feedback]
            elif noise_type == 'Context_Inclusive':
                idx_feedback = train_data.targets == (context_num + 6 * latent_reward - 3) % 10
                feedback_set = train_data.data[idx_feedback]
            elif noise_type == 'Action_Context_Inclusive':
                idx_feedback = train_data.targets == (context_num + rand_action + 6 * latent_reward - 3) % 10
                feedback_set = train_data.data[idx_feedback]
            elif noise_type == 'Independent':
                if latent_reward == 1:
                    idx_feedback = train_data_emnist.targets == 20  # if correct, sample an image of 't' (TURE)
                else:
                    idx_feedback = train_data_emnist.targets == 6  # if wrong, sample an image of 'f' (FALSE)
                feedback_set = train_data_emnist.data[idx_feedback]
        feedback = feedback_set[np.random.randint(len(feedback_set))].float().to(device)
        D.append((context, one_hot_action, feedback, latent_reward, context_num, rand_action))
    return D


# Collect test dataset
def test_data_collection(noise_type, noise, num_data, device):
    D = []
    for _ in range(num_data):
        context_num = np.random.randint(10)
        idx_context = test_data.targets == context_num
        context_set = test_data.data[idx_context]
        context = context_set[np.random.randint(len(context_set))].float().to(device)

        rand_action = np.random.randint(10)
        one_hot_action = torch.zeros((1, 10)).to(device)
        one_hot_action[0][rand_action] = 1

        if context_num == rand_action:
            latent_reward = 1
        else:
            latent_reward = 0

        if np.random.random() >= noise:
            idx_feedback = test_data.targets == latent_reward
            feedback_set = test_data.data[idx_feedback]
        else:
            if noise_type == 'Action_Inclusive':
                idx_feedback = test_data.targets == (rand_action + 6 * latent_reward - 3) % 10
                feedback_set = test_data.data[idx_feedback]
            elif noise_type == 'Context_Inclusive':
                idx_feedback = test_data.targets == (context_num + 6 * latent_reward - 3) % 10
                feedback_set = test_data.data[idx_feedback]
            elif noise_type == 'Action_Context_Inclusive':
                idx_feedback = test_data.targets == (context_num + rand_action + 6 * latent_reward - 3) % 10
                feedback_set = test_data.data[idx_feedback]
            elif noise_type == 'Independent':
                if latent_reward == 1:
                    idx_feedback = test_data_emnist.targets == 20
                else:
                    idx_feedback = test_data_emnist.targets == 6
                feedback_set = test_data_emnist.data[idx_feedback]
        feedback = feedback_set[np.random.randint(len(feedback_set))].float().to(device)
        D.append((context, one_hot_action, feedback, latent_reward, context_num))
    return D
